{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quadrature\n",
    "\n",
    "This is a notebook that illustrates differences in the quadrature methods which are central to the evaluation of\n",
    "\n",
    "$$\n",
    "\\int_0^{\\pi} f(\\theta) \\sin \\theta \\; \\mathrm{d}\\theta = \\int_{-1}^{1} f(\\cos \\theta) \\; \\mathrm{d}\\cos \\theta.\n",
    "$$\n",
    "\n",
    "Quadrature is used to compute the projection onto the associated Legendre functions in `paddle-harmonics`.\n",
    "\n",
    "In order to illustrate how interpolation and quadrature affect the error in the computation of the SHT, this notebook contains examples of both kinds of error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import paddle\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_theta = 80"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle_harmonics\n",
    "from paddle_harmonics.quadrature import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test interpolation\n",
    "\n",
    "We first assess the interpolation onto the quadrature nodes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# piecewise-linear interpolation; query points that coincide with the grid end points are handled\n",
    "def interpolate(t, tq, f):\n",
    "    \"\"\"\n",
    "    Piecewise-linear interpolation of the samples f, given on the sorted grid t,\n",
    "    onto the query points tq.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    t : 1-D numpy array of increasing sample locations (at least two entries)\n",
    "    tq : 1-D numpy array of query locations\n",
    "    f : 1-D numpy array with the samples at t\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    paddle.Tensor holding the interpolated values at tq\n",
    "    \"\"\"\n",
    "    t = np.asarray(t)\n",
    "    tq = np.asarray(tq)\n",
    "    f = np.asarray(f)\n",
    "\n",
    "    # index j of the interval [t[j], t[j+1]] used for each query point.\n",
    "    # searchsorted returns 0 for a query equal to t[0], which would give j = -1 and\n",
    "    # silently wrap around to the last interval; clamping keeps both j and j+1 valid.\n",
    "    # Queries outside [t[0], t[-1]] are extrapolated from the nearest end interval.\n",
    "    j = np.clip(np.searchsorted(t, tq) - 1, 0, len(t) - 2)\n",
    "    # relative position of each query point inside its interval\n",
    "    d = paddle.to_tensor((tq - t[j]) / (t[j + 1] - t[j]))\n",
    "    # f[j] + d * (f[j+1] - f[j])\n",
    "    interp = paddle.lerp(paddle.to_tensor(f[j]), paddle.to_tensor(f[j + 1]), d)\n",
    "    return interp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nodes and weights returned by legendre_gauss_weights for n_theta points, called with bounds -1, 1\n",
    "# (lobatto_weights(n_theta, -1, 1) can be substituted here)\n",
    "cost_lg, wlg = legendre_gauss_weights(n_theta, -1, 1)\n",
    "# node angles (arccos of the nodes), in reversed order\n",
    "tq = np.arccos(cost_lg)[::-1]\n",
    "# equidistant angles on [0, pi] with the same number of points\n",
    "teq = np.linspace(0.0, np.pi, num=n_theta, dtype=np.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# node value against node index: equidistant grid first, quadrature nodes second\n",
    "fig, ax = plt.subplots()\n",
    "for nodes in (teq, tq):\n",
    "    ax.plot(nodes, '.')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "test interpolation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(t):\n",
    "    \"\"\"Test signal cos(4 t) evaluated at the angles t.\"\"\"\n",
    "    return np.cos(4*t)\n",
    "\n",
    "# alternative test signals (replace the return expression above):\n",
    "#   1 / (1 + 25 * (2*(t-np.pi/2)/np.pi)**2)\n",
    "#   1 / (1 + 25 * np.cos(t)**2)\n",
    "#   t**5 - 3*t**2 - 2*t + 1.0\n",
    "\n",
    "# samples on the equidistant grid, interpolated onto the quadrature nodes\n",
    "interp = interpolate(teq, tq, f(teq))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "interp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f(teq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# exact values of the test signal at the quadrature nodes\n",
    "plt.plot(tq, f(tq), '.-', label=\"reference\")\n",
    "# values interpolated from the equidistant grid; interp is a paddle tensor,\n",
    "# so it is converted to numpy before being handed to matplotlib\n",
    "plt.plot(tq, interp.numpy(), '.-', label=\"interpolated\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test quadrature with associated Legendre polynomials\n",
    "\n",
    "let us test different quadrature modes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def precompute_legpoly(m_max, l_max, x):\n",
    "    r\"\"\"\n",
    "    Computes the values of P^m_n(\\cos \\theta) at the positions specified by x (theta)\n",
    "    The resulting tensor has shape (m_max, l_max, len(x))\n",
    "\n",
    "    The table is built by recursion starting from P^0_0 = sqrt(1 / (4 pi)).\n",
    "    Entries with l < m (and rows with m >= l_max) are left at zero.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    m_max : number of orders m to tabulate (first axis), at least 1\n",
    "    l_max : number of degrees l to tabulate (second axis), at least 1\n",
    "    x : 1-D numpy array of angles theta\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    paddle.Tensor created from a float64 array of shape (m_max, l_max, len(x))\n",
    "    \"\"\"\n",
    "\n",
    "    # compute the tensor P^m_n:\n",
    "    pct = np.zeros((m_max, l_max, len(x)), dtype=np.float64)\n",
    "\n",
    "    sinx = np.sin(x)\n",
    "    cosx = np.cos(x)\n",
    "\n",
    "    # coefficients of the three-term recurrence in l used below\n",
    "    a = lambda m, l: np.sqrt((4*l**2 - 1) / (l**2 - m**2))\n",
    "    b = lambda m, l: -1 * np.sqrt((2*l+1)/(2*l-3)) * np.sqrt(((l-1)**2 - m**2)/(l**2 - m**2))\n",
    "\n",
    "    # start by populating the diagonal and the second higher diagonal\n",
    "    amm = np.sqrt( 1. / (4 * np.pi) )\n",
    "    pct[0,0,:] = amm\n",
    "    # the l = 1 column only exists when l_max > 1\n",
    "    if l_max > 1:\n",
    "        pct[0,1,:] = a(0, 1) * cosx * amm\n",
    "    for m in range(1, min(m_max, l_max)):\n",
    "        pct[m,m,:] = -1*np.sqrt( (2*m+1) / (2*m) ) * pct[m-1,m-1,:] * sinx\n",
    "        if m + 1 < l_max:\n",
    "            pct[m,m+1,:] = a(m, m+1) * cosx * pct[m,m,:]\n",
    "\n",
    "    # fill the remaining values on the upper triangle\n",
    "    for m in range(0, m_max):\n",
    "        for l in range(m+2, l_max):\n",
    "            pct[m,l,:] = a(m,l) * cosx * pct[m,l-1,:] + b(m,l) * pct[m,l-2,:]\n",
    "\n",
    "    return paddle.to_tensor(pct)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us plot the Legendre polynomials:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = 0  # order of the functions to plot\n",
    "\n",
    "pct = np.sqrt(2 * np.pi) * precompute_legpoly(n_theta, n_theta, teq)\n",
    "\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "# first six degrees l of order m, plotted against cos(theta)\n",
    "for l in range(6):\n",
    "    ax.plot(np.cos(teq), pct[m, l].numpy(), label=f\"l = {l}\")\n",
    "ax.set_xlabel(r\"$\\cos \\theta$\")\n",
    "ax.legend()\n",
    "# plt.show() matches the other cells; fig.show() warns on non-GUI backends\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def project(t, w, f, mmax=None):\n",
    "    \"\"\"\n",
    "    Projects the samples f onto the order m = 0 functions from precompute_legpoly\n",
    "    using the quadrature weights w, then evaluates the expansion at the same nodes.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    t : 1-D numpy array of node angles theta\n",
    "    w : quadrature weights, one per node\n",
    "    f : samples at the nodes (the last axis indexes the nodes)\n",
    "    mmax : number of degrees l used in the expansion; defaults to len(t)\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    paddle.Tensor with the reconstruction at the nodes t\n",
    "    \"\"\"\n",
    "    m = 0\n",
    "    if mmax is None:\n",
    "        mmax = len(t)\n",
    "\n",
    "    weights = paddle.to_tensor(w)\n",
    "    # only row m is used, so tabulate orders 0..m rather than the full (mmax, mmax) table;\n",
    "    # row m depends only on rows <= m, so its values are unchanged\n",
    "    pct = (np.sqrt(2 * np.pi) * precompute_legpoly(m + 1, mmax, t))[m]\n",
    "    # basis functions scaled by the quadrature weights\n",
    "    weighted = paddle.einsum('lk,k->lk', pct, weights)\n",
    "\n",
    "    # forward step: weighted sum of f against each basis function\n",
    "    proj = paddle.einsum('...k,lk->...l', paddle.to_tensor(f), weighted)\n",
    "    # backward step: evaluate the expansion at the nodes\n",
    "    rec = paddle.einsum('...l, lk->...k', proj, pct)\n",
    "    return rec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "let us compare the accuracy of the different projection methods:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reference signal on an equidistant grid\n",
    "t = np.linspace(0, np.pi, n_theta)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(t, f(t), label=\"reference\")\n",
    "\n",
    "# project the exact node samples with each quadrature rule and overlay the reconstructions\n",
    "for quadrature in (legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights, fejer2_weights):\n",
    "    cost, wq = quadrature(n_theta, -1, 1)\n",
    "    tq = np.flip(np.arccos(cost))\n",
    "    out = project(tq, wq, f(tq))\n",
    "    ax.plot(tq, out.numpy(), '.-', label=quadrature.__name__)\n",
    "\n",
    "ax.legend(loc='lower left')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for quadrature in [legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights]:\n",
    "    cost, wq = quadrature(n_theta, -1, 1)\n",
    "    tq = np.flip(np.arccos(cost))\n",
    "\n",
    "    out = project(tq, wq, f(tq))\n",
    "\n",
    "    # a logarithmic axis cannot display negative values, so plot the error magnitude\n",
    "    plt.semilogy(tq, np.abs(out.numpy() - f(tq)), '.-', label=quadrature.__name__)\n",
    "\n",
    "plt.ylabel(\"absolute error\")\n",
    "plt.legend(loc='lower left')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us now add interpolation into the mix to evaluate performance with interpolation taken into account. For this particular case, we will assume that the data is given to us on an equidistant grid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data as given: samples of f on an equidistant grid\n",
    "t = np.linspace(0, np.pi, n_theta)\n",
    "ref = f(t)\n",
    "plt.plot(t, ref, label=\"reference\")\n",
    "\n",
    "for quadrature in [legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights]:\n",
    "    cost, wq = quadrature(n_theta, -1, 1)\n",
    "    tq = np.flip(np.arccos(cost))\n",
    "    mmax = len(tq)\n",
    "\n",
    "    if quadrature in (lobatto_weights, legendre_gauss_weights):\n",
    "        # the data has to be interpolated onto the nodes of these rules\n",
    "        f_interp = interpolate(t, tq, ref)\n",
    "    else:\n",
    "        # the data is used as given (presumably these nodes coincide with t - confirm)\n",
    "        f_interp = ref\n",
    "\n",
    "    out = project(tq, wq, f_interp, mmax=mmax)\n",
    "\n",
    "    # out is a paddle tensor; convert it for matplotlib as the other cells do\n",
    "    plt.plot(tq, out.numpy(), '.-', label=quadrature.__name__)\n",
    "\n",
    "plt.legend(loc='lower left')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "again, let us plot the overall error, this time including the interpolation error:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data as given: samples of f on an equidistant grid\n",
    "t = np.linspace(0, np.pi, n_theta)\n",
    "ref = f(t)\n",
    "\n",
    "fig, ax = plt.subplots(2, 1)\n",
    "\n",
    "for quadrature in [legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights]:\n",
    "    cost, wq = quadrature(n_theta, -1, 1)\n",
    "    tq = np.flip(np.arccos(cost))\n",
    "    mmax = len(tq)\n",
    "\n",
    "    if quadrature in (lobatto_weights, legendre_gauss_weights):\n",
    "        # the data has to be interpolated onto the nodes of these rules\n",
    "        f_interp = interpolate(t, tq, ref)\n",
    "    else:\n",
    "        # the data is used as given (presumably these nodes coincide with t - confirm)\n",
    "        f_interp = ref\n",
    "\n",
    "    out = project(tq, wq, f_interp, mmax=mmax)\n",
    "\n",
    "    # a logarithmic axis cannot display negative values, so plot error magnitudes\n",
    "    ax[0].semilogy(tq, np.abs(out.numpy() - f(tq)), '.-', label=quadrature.__name__)\n",
    "    if isinstance(f_interp, paddle.Tensor):\n",
    "        f_interp = f_interp.numpy()\n",
    "    ax[1].semilogy(tq, np.abs(f_interp - f(tq)), '.-', label=quadrature.__name__)\n",
    "\n",
    "ax[0].set_title(\"Projection error after interpolation\")\n",
    "ax[1].set_title(\"Interpolation error\")\n",
    "# the curves carry labels, so show the legends to tell the rules apart\n",
    "ax[0].legend(loc='lower left')\n",
    "ax[1].legend(loc='lower left')\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the interpolation error dominates when we interpolate the solution. For this reason, it is reasonable to choose Clenshaw-Curtis quadrature in scenarios where we expect the interpolation error to dominate."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6 (main, Nov 14 2022, 16:10:14) [GCC 11.3.0]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
